
Conversation

@VolodyaCO (Collaborator)

This PR adds an orthogonal layer parameterized by Givens rotations, using the parallel algorithm described by Firas in https://arxiv.org/abs/2106.00003, which gives O(n) forward complexity and O(n log n) backward complexity, even though there are O(n^2) rotations.
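As a rough illustration of where the parallelism comes from (a sketch, not necessarily the exact construction used here): the (n - 1) * n // 2 rotations can be grouped, via a round-robin schedule, into n - 1 blocks of n // 2 pairwise-disjoint index pairs, and the rotations within a block commute, so each block can be applied in one parallel step:

```python
def givens_blocks(n: int) -> list[list[tuple[int, int]]]:
    """Round-robin schedule: n - 1 blocks of n // 2 disjoint pairs (even n only)."""
    others = list(range(1, n))
    blocks = []
    for _ in range(n - 1):
        ring = [0] + others
        # Pair up opposite ends of the ring; all n // 2 pairs are disjoint.
        blocks.append([(ring[i], ring[n - 1 - i]) for i in range(n // 2)])
        others = others[1:] + others[:1]  # rotate the ring for the next block
    return blocks
```

For n = 4 this yields [(0, 3), (1, 2)], [(0, 1), (2, 3)], [(0, 2), (3, 1)]: every unordered pair appears exactly once.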

This PR is still in draft. I wrote it for even n. Some more unit tests probably need to be added, but I am quite lazy (I will do them once all the math is checked for odd n).

@VolodyaCO requested a review from kevinchern December 19, 2025 00:01
@VolodyaCO self-assigned this Dec 19, 2025
@VolodyaCO added the enhancement label Dec 19, 2025
@VolodyaCO (Collaborator, Author)

I somehow broke @kevinchern's tests, what the hell...

def test_store_config(self):
    with self.subTest("Simple case"):

        class MyModel(torch.nn.Module):
@kevinchern (Collaborator) commented Dec 19, 2025

Remove formatting changes. Is this "black" formatting?

@VolodyaCO (Collaborator, Author)

Yes. I have it on by default in my VS Code.

@kevinchern (Collaborator) commented Dec 19, 2025

> I somehow broke @kevinchern's tests, what the hell...

@VolodyaCO which tests? I'm seeing test_forward_agreement and test_backward_agreement failures on this CI test

@VolodyaCO (Collaborator, Author)

> > I somehow broke @kevinchern's tests, what the hell...
>
> @VolodyaCO which tests? I'm seeing test_forward_agreement and test_backward_agreement failures on this CI test

I forgot to update my tests to float64 precision. Now that I've done that, it's weird that all of the currently failing tests are failing on

  File "/Users/distiller/project/tests/test_nn.py", line 144, in test_LinearBlock
    self.assertTrue(model_probably_good(model, (din,), (dout,)))

@kevinchern (Collaborator)

> > > I somehow broke @kevinchern's tests, what the hell...
> >
> > @VolodyaCO which tests? I'm seeing test_forward_agreement and test_backward_agreement failures on this CI test
>
> I forgot to update my tests to float64 precision. Now that I've done that, it's weird that all of the currently failing tests are failing on
>
>   File "/Users/distiller/project/tests/test_nn.py", line 144, in test_LinearBlock
>     self.assertTrue(model_probably_good(model, (din,), (dout,)))

Ahhhhhh. OK, Theo also flagged this at #50. It's a poorly written test... you can ignore it.

@VolodyaCO marked this pull request as ready for review December 26, 2025 17:28
Returns:
    list[list[tuple[int, int]]]: Blocks of edges for parallel Givens rotations.
Note:
(Contributor)

@VolodyaCO (Collaborator, Author)

Where should I put this? In the release notes, or in the docstring itself?

(Contributor)

Simply change the Note: to

.. note::

    Lorem ipsum...

which would render a note box if we generate docs with Sphinx.
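For example, in context (a hypothetical sketch; the helper's actual name may differ):

```python
def make_blocks(n: int) -> list[list[tuple[int, int]]]:
    """Group the Givens rotations into blocks that can be applied in parallel.

    Returns:
        list[list[tuple[int, int]]]: Blocks of edges for parallel Givens rotations.

    .. note::
        Only even ``n`` is supported for now.
    """
```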

Comment on lines 81 to 85
angles (torch.Tensor): A ((n - 1) * n // 2,) shaped tensor containing all rotations
between pairs of dimensions.
blocks (torch.Tensor): A (n-1, n//2, 2) shaped tensor containing the indices that
specify rotations between pairs of dimensions. Each of the n-1 blocks contains n//2
pairs of independent rotations.
(Contributor)

Code formatting?

Suggested change
- angles (torch.Tensor): A ((n - 1) * n // 2,) shaped tensor containing all rotations
-     between pairs of dimensions.
- blocks (torch.Tensor): A (n-1, n//2, 2) shaped tensor containing the indices that
-     specify rotations between pairs of dimensions. Each of the n-1 blocks contains n//2
-     pairs of independent rotations.
+ angles (torch.Tensor): A ``((n - 1) * n // 2,)`` shaped tensor containing all rotations
+     between pairs of dimensions.
+ blocks (torch.Tensor): A ``(n - 1, n // 2, 2)`` shaped tensor containing the indices that
+     specify rotations between pairs of dimensions. Each of the n-1 blocks contains n // 2
+     pairs of independent rotations.

@VolodyaCO (Collaborator, Author)

Done, thanks.

"""
# Blocks is of shape (n_blocks, n/2, 2) containing indices for angles
# Within each block, each Givens rotation is commuting, so we can apply them in parallel
U = torch.eye(n, device=angles.device, dtype=angles.dtype)
(Contributor)

Slight preference to keep variables lower-case.

@VolodyaCO (Collaborator, Author)

I changed this in the main GivensRotationLayer class. In the other code, I kept the capital letters so that if someone reads the algorithm in the paper alongside the code, each part of the algorithm is easier to follow.
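For the record, this is roughly what applying one block looks like when vectorized (a minimal sketch with assumed names `pairs` and `theta`, not the exact code in this PR):

```python
import torch

def apply_block(u: torch.Tensor, pairs: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
    """Apply one block of n // 2 disjoint Givens rotations to u in a single step."""
    i, j = pairs[:, 0], pairs[:, 1]  # (n // 2,) index tensors; pairs are disjoint
    c = torch.cos(theta).unsqueeze(1)
    s = torch.sin(theta).unsqueeze(1)
    ui, uj = u[i], u[j]  # advanced indexing copies the selected rows
    out = u.clone()
    out[i] = c * ui - s * uj  # disjoint pairs guarantee these row
    out[j] = s * ui + c * uj  # updates never collide
    return out
```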

Comment on lines +122 to +129
angles, blocks, Ufwd_saved = ctx.saved_tensors
Ufwd = Ufwd_saved.clone()
M = grad_output.t() # dL/dU, i.e., grad_output is of shape (n, n)
n = M.size(1)
block_size = n // 2
A = torch.zeros((block_size, n), device=angles.device, dtype=angles.dtype)
(Contributor)

Same here re lowercase for Ufwd, M, and A. It avoids incorrect colour highlighting in some editor themes.

@VolodyaCO (Collaborator, Author)

Hmmm, I didn't read this comment about the incorrect colour highlighting before I made my previous comment. I still think it is easier to read the algorithm alongside the code if the lower/upper-case usage matches. For example, a lower-case m usually denotes an integer variable, not a tensor.

        return U

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
(Contributor)

Missing return type hint.

@VolodyaCO (Collaborator, Author)

I added the type hint as well as a longer explanation of what this return is.
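For reference, the shape of what I added looks like this (a sketch; the class name here is made up): backward returns one value per input to forward, with None for the non-differentiable index blocks.

```python
from typing import Optional, Tuple

import torch

class GivensRotationFn(torch.autograd.Function):  # hypothetical name, for illustration
    @staticmethod
    def forward(ctx, angles: torch.Tensor, blocks: torch.Tensor) -> torch.Tensor:
        ...

    @staticmethod
    def backward(
        ctx, grad_output: torch.Tensor
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
        # One entry per forward input: a gradient tensor for the angles,
        # None for the integer index blocks (not differentiable).
        ...
```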

U = self._create_rotation_matrix()
rotated_x = einsum(x, U, "... i, o i -> ... o")
if self.bias is not None:
    rotated_x = rotated_x + self.bias
(Contributor)

Suggested change
- rotated_x = rotated_x + self.bias
+ rotated_x += self.bias

@VolodyaCO (Collaborator, Author)

Done.

from einops import einsum


class NaiveGivensRotationLayer(nn.Module):
(Contributor)

I'm not very keen on having a full-on separate implementation here just to compare with/test the GivensRotationLayer. If this NaiveGivensRotationLayer is useful, should it be part of the package instead?

@VolodyaCO (Collaborator, Author)

We discussed this in our one-on-one but, just for the record: there is no difference between the NaiveGivensRotationLayer and the GivensRotationLayer in the forward or backward passes. The naïve implementation is there to make sure that the forward and backward passes indeed match. The GivensRotationLayer should always be used because it has substantially better runtime complexity; the naïve implementation is not useful other than as a sanity check.
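To make that concrete, a naïve layer is essentially this (a minimal sketch with assumed names and an arbitrary pair ordering, not the exact class in this PR): it multiplies the O(n^2) two-by-two rotations one at a time and lets autograd handle the backward pass.

```python
import torch
from torch import nn

class NaiveGivens(nn.Module):  # illustrative stand-in for NaiveGivensRotationLayer
    def __init__(self, n: int):
        super().__init__()
        self.n = n
        # One angle per unordered pair (i, j), in an arbitrary fixed order.
        self.pairs = [(i, j) for i in range(n) for j in range(i + 1, n)]
        self.angles = nn.Parameter(torch.zeros(len(self.pairs), dtype=torch.float64))

    def rotation_matrix(self) -> torch.Tensor:
        u = torch.eye(self.n, dtype=self.angles.dtype)
        for theta, (i, j) in zip(self.angles, self.pairs):
            g = torch.eye(self.n, dtype=self.angles.dtype)
            c, s = torch.cos(theta), torch.sin(theta)
            g[i, i] = g[j, j] = c
            g[i, j], g[j, i] = -s, s
            u = g @ u  # one dense matmul per rotation: slow, but obviously correct
        return u

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x @ self.rotation_matrix().t()
```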

tests/test_nn.py (outdated)
Comment on lines 91 to 92
@parameterized.expand([(n, bias) for n in [4, 5, 6, 9, 10] for bias in [True, False]])
def test_forward_agreement(self, n, bias):
(Contributor)

These tests do seem a bit too complex. Better to try to test more minimal aspects of the class, if possible. I'd much rather have separate integration-like tests that assert that models behave as expected, while having these be strictly small-scale, isolated unit tests.

@VolodyaCO (Collaborator, Author)

I added some tests for invalid inputs too. These forward and backward tests check that the correct output is produced when compared with the naïve implementation. The model_probably_good test is done as a unit test.
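The agreement tests boil down to something like this (a sketch; the attribute names are assumptions): copy the same angles into both layers and compare outputs at float64 precision.

```python
import torch

def forward_agreement(fast_layer, naive_layer, n: int) -> bool:
    # Assumes both layers expose a flat `angles` parameter of the same shape.
    with torch.no_grad():
        naive_layer.angles.copy_(fast_layer.angles)
    x = torch.randn(3, n, dtype=torch.float64)
    return torch.allclose(fast_layer(x), naive_layer(x), atol=1e-10)
```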

@VolodyaCO (Collaborator, Author)

After a bit of git wrangling, I was able to clean my whole mess of merge commits 😆.
